import numpy as np
import pdb
import os
import matplotlib as mpl
from matplotlib import pyplot as plt
import matplotlib.ticker as ticker
import seaborn as sns
import torch
from itertools import chain, combinations
from algos.ardqn.policies import DQNPolicy
import random
#from sklearn.preprocessing import MinMaxScaler
#from sklearn import preprocessing

import warnings
# Escalate every warning to an exception, process-wide, as a side effect of
# importing this module. This includes NumPy floating-point RuntimeWarnings
# (e.g. overflow in np.exp, divide-by-zero, invalid value), so those abort
# with an error instead of silently producing inf/nan.
warnings.filterwarnings("error")

def set_seed_everywhere(seed):
    """Seed Python's `random`, NumPy's global RNG and PyTorch with `seed`.

    When CUDA is available, every CUDA device's generator is seeded as well.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.manual_seed_all(seed)

def collect_data(env, policy, num_trajectory, truncated_horizon):
    """Roll out `policy` in `env` and record every visited transition.

    Each trajectory starts from `env.reset()` and runs for at most
    `truncated_horizon` steps, stopping early when the environment reports
    `done`.

    :param env: Environment whose `reset()` returns a pair whose first item
        is the initial state, and whose `step(action)` returns a 5-tuple
        `(next_state, reward, done, _, info)` with `info['backlog']` present.
        The fourth item is ignored (presumably a gymnasium-style `truncated`
        flag -- TODO confirm).
    :param policy: Callable `policy(state, t) -> action`, where `t` is the
        step index within the current trajectory.
    :param num_trajectory: Number of trajectories to collect.
    :param truncated_horizon: Maximum number of steps per trajectory.
    :returns: `(paths, avg_reward)`. `paths` holds one dict per trajectory
        with per-step lists under 'obs', 'nobs', 'acts', 'rews',
        'avg_backlog' and 'avg_backlog_change' (the last is never filled).
        `avg_reward` is the summed reward divided by the number of recorded
        steps, or 0.0 when no steps were recorded.
    """
    paths = []
    num_samples = 0
    total_reward = 0.0
    for _ in range(num_trajectory):
        path = {key: [] for key in ('obs', 'nobs', 'acts', 'rews',
                                    'avg_backlog', 'avg_backlog_change')}
        state, _ = env.reset()
        for t in range(truncated_horizon):
            action = policy(state, t)
            next_state, reward, done, _truncated, info = env.step(action)
            path['obs'].append(state)
            path['acts'].append(action)
            path['rews'].append(reward)
            path['nobs'].append(next_state)
            path['avg_backlog'].append(info['backlog'])
            # 'avg_backlog_change' is kept in the output for compatibility
            # but left empty (its append was disabled).
            total_reward += reward
            state = next_state
            if done:
                break
        paths.append(path)
        num_samples += len(path['obs'])
    # With no recorded steps (num_trajectory == 0 or truncated_horizon == 0)
    # the average is undefined; return 0.0 instead of dividing by zero.
    avg_reward = total_reward / num_samples if num_samples else 0.0
    return paths, avg_reward

def get_MSE(true_val, pred_vals):
    """Return confidence-interval statistics of the squared errors.

    `true_val` is broadcast against `pred_vals`. Each element-wise squared
    difference is treated as one sample, and the summary dict is produced
    by `get_CI`.
    """
    residuals = np.array(true_val) - np.array(pred_vals)
    return get_CI(np.square(residuals))

# statistics/visualization related
def get_CI(data, confidence = 0.95):
    """Summarise `data` with a normal-approximation confidence interval of its mean.

    The half-width is ``err = z * (std / sqrt(n))``, where `std` is the
    population standard deviation (``np.std``, ``ddof=0``) and `n` is the
    length of `data` along its first axis. `z` is the conventional rounded
    value 1.96 for 95% and 2.576 for 99%. Any other confidence in (0, 1)
    uses the exact two-sided standard-normal quantile.

    :param data: array-like of numbers (presumably 1-D, since `n` is taken
        from the first axis while mean/std/max/min use all elements).
    :param confidence: Two-sided confidence level, strictly between 0 and 1.
    :returns: dict with keys 'mean', 'std', 'lower', 'upper', 'err', 'max',
        'min'. Returns an empty dict if `data` is empty or holds only None.
    :raises ValueError: if `confidence` is not strictly between 0 and 1.
    """
    arr = np.asarray(data)
    # Nothing to summarise. Only object arrays can hold None, so numeric
    # arrays skip the per-element scan.
    if arr.size == 0 or (arr.dtype == object
                         and all(v is None for v in arr.flat)):
        return {}

    if confidence == 0.95:
        z = 1.96
    elif confidence == 0.99:
        z = 2.576
    else:
        # Previously any other level left `z` unbound (UnboundLocalError).
        if not 0 < confidence < 1:
            raise ValueError(
                "confidence must be in (0, 1), got {!r}".format(confidence))
        from statistics import NormalDist
        z = NormalDist().inv_cdf(0.5 + confidence / 2)

    n = len(arr)
    mean = np.mean(arr)
    std = np.std(arr)
    err = z * (std / np.sqrt(n))
    return {
        'mean': mean,
        'std': std,
        'lower': mean - err,
        'upper': mean + err,
        'err': err,
        'max': np.max(arr),
        'min': np.min(arr),
    }

def plot_heatmap(pi, name, transformation = None, within_callback = False, env = None):
    """Plot and save heatmaps of the probability that policy `pi` serves queue 1.

    For each of three fixed values of the two trailing state components
    (`cons` = [1, 1], [1, 0], [0, 1], labelled 'both', 'only1', 'only2'),
    the first two state components (q1, q2) are swept over 0..100. `pi` is
    queried at every grid point for the probability of action 0 ("serving
    Q1"), and a seaborn heatmap is saved to '<name>_<label>_heat.pdf'. The
    function makes 3 * 101 * 101 single-state policy queries.

    Args:
        pi: Policy wrapper. If `pi.policy` is a `DQNPolicy`, its Boltzmann
            distribution or greedy `q_net.predict` is used. Otherwise the
            action distribution comes from `pi.policy` (when
            `within_callback`) or `pi.pi.policy`. Presumably a
            stable-baselines-style model -- TODO confirm against callers.
        name: Filename prefix for the saved PDFs.
        transformation: Not used by this function.
        within_callback: If True, each raw state is first passed through
            `env.env_method('transform_state', st)[0]`.
        env: Object exposing `env_method`; only used when `within_callback`
            is True.
    """
    bounds = 101
    #for cons in [[0, 1], [1, 0], [1,1]]:
    # Filename labels, index-aligned with the `cons` list in the loop below.
    # NOTE(review): presumably `cons` flags which queues are available -- confirm.
    types = ['both', 'only1', 'only2']
    #for idx, cons in enumerate([[0, 1, 0, 1]]):#, [0, 1, 1, 0], [1, 0, 0, 1]]):
    for idx, cons in enumerate([[1, 1], [1, 0], [0,1]]):
        # Rows index Q1 length, columns Q2 length (matches the axis labels).
        ma = np.zeros((bounds, bounds))

        for q1 in range(0, bounds):
            for q2 in range(0, bounds):
                # State = [q1, q2, *cons].
                lens = np.array([q1, q2])
                st = np.concatenate((lens, np.array(cons)))
                st = st.tolist()

                if within_callback:
                    st = env.env_method('transform_state', st)[0]
                    st = st.tolist()

                # Batch of one state. NOTE(review): when within_callback is
                # False, `st` is a list of Python ints, so `obs` is an integer
                # tensor -- confirm the policy accepts that dtype.
                obs = torch.tensor([st])

                if hasattr(pi, 'policy') and isinstance(pi.policy, DQNPolicy):
                    if pi.policy.boltzmann_exp:
                        dis = pi.policy.get_distribution(obs)
                        probs = dis.probs
                        probs_np = probs.detach().numpy()[0]
                        prob = probs_np[0]
                    else:
                        # Greedy DQN: probability is 1 if the chosen action is 0, else 0.
                        # NOTE(review): assumes predict's output supports
                        # `(act[0] == 0).astype(int)` (numpy-like) -- TODO confirm.
                        act = pi.policy.q_net.predict(obs)[0]
                        # if action is queue 1, then prob 1
                        prob = (act[0] == 0).astype(int)
                else:
                    if within_callback:
                        dis = pi.policy.get_distribution(obs)
                    else:
                        dis = pi.pi.policy.get_distribution(obs)
                    probs = dis.distribution.probs
                    probs_np = probs.detach().numpy()[0]
                    prob = probs_np[0] # prob of serving queue 1
                ma[q1, q2] = prob 

        ax = sns.heatmap(ma, linewidth=0.5)
        # Put q1 = 0 at the bottom of the plot instead of the top.
        ax.invert_yaxis()
        plt.title('P(serving Q1)')
        plt.ylabel('Q1 length')
        plt.xlabel('Q2 length')
        #plt.imshow(ma, cmap = 'hot')
        #plt.savefig('{}_{}_{}_heat.pdf'.format(name, cons[0], cons[1]))
        plt.savefig('{}_{}_heat.pdf'.format(name, types[idx]))

        # Close the figure so the next heatmap starts on a fresh one.
        plt.close()

def powerset(iterable):
    """Return an iterator over all non-empty subsets of `iterable` as tuples.

    Subsets are produced by increasing size. Within one size they follow
    `itertools.combinations` order. The empty tuple is not included.
    """
    items = list(iterable)
    subset_sizes = range(1, len(items) + 1)
    per_size = (combinations(items, size) for size in subset_sizes)
    return chain.from_iterable(per_size)

def symlog(x, base = 'e'):
    """Signed logarithmic squashing: ``sign(x) * log_base(|x| + 1)``.

    The function is odd, monotone and zero at zero, so it compresses large
    magnitudes while keeping the sign.

    :param x: scalar or array-like.
    :param base: 'e' (default), '10' or '2'. The ints 10 and 2 are also
        accepted as aliases.
    :raises ValueError: for any other base. Previously an unknown base
        silently returned None.
    """
    log_fns = {'e': np.log, '10': np.log10, '2': np.log2}
    try:
        log_fn = log_fns[str(base)]
    except KeyError:
        raise ValueError(
            "unsupported base {!r}; expected 'e', '10' or '2'".format(base)
        ) from None
    return np.sign(x) * log_fn(np.abs(x) + 1)

def symsqrt(x):
    """Signed square-root squashing: ``sign(x) * (sqrt(|x| + 1) - 1)``.

    The result is zero at zero, odd in `x`, and grows like sqrt(|x|).
    """
    magnitude = np.sqrt(np.abs(x) + 1) - 1
    return np.sign(x) * magnitude

def sigmoid(x):
    """Element-wise logistic function ``1 / (1 + exp(-x))``, numerically stable.

    The naive form overflows in ``exp(-x)`` for large negative `x`. This
    module turns warnings into errors, so that overflow RuntimeWarning would
    raise. Here the exponential is only ever taken of ``-|x|`` (<= 0), which
    cannot overflow:

    * x >= 0: ``1 / (1 + exp(-x))``
    * x <  0: ``exp(x) / (1 + exp(x))``

    :param x: scalar or array-like. Integer input is promoted to float64;
        float32 input stays float32.
    :returns: ndarray of the same shape, or a NumPy scalar for scalar input.
    """
    x = np.asarray(x)
    x = x.astype(np.result_type(x, 1.0), copy=False)
    z = np.exp(-np.abs(x))  # in (0, 1], never overflows
    out = np.where(x >= 0, 1. / (1. + z), z / (1. + z))
    # Match the original's scalar-in / scalar-out behaviour.
    return out if out.ndim else out[()]

def tanh(x):
    """Element-wise hyperbolic tangent; a thin alias for ``np.tanh``."""
    squashed = np.tanh(x)
    return squashed

def ewma_vectorized(data, alpha, offset=None, dtype=None, order='C', out=None):
    """
    Calculates the exponential moving average over a vector.

    Evaluates ``y[i] = alpha * x[i] + (1 - alpha) * y[i - 1]`` with
    ``y[-1] = offset`` in closed form, without a Python loop.

    Will fail for large inputs: the scaling factors ``(1 - alpha) ** k``
    underflow to 0 as the input grows, which leads to division by zero.
    The same happens for ``alpha == 1``.

    :param data: Input data (array-like). Input that is not 1-D is
        flattened using `order`.
    :param alpha: scalar float in range (0,1)
        The alpha parameter for the moving average.
    :param offset: optional
        The offset for the moving average, scalar. Defaults to data[0].
    :param dtype: optional
        Data type used for calculations. Defaults to float64 unless
        data.dtype is float32, then it will use float32.
    :param order: {'C', 'F', 'A'}, optional
        Order to use when flattening the data. Defaults to 'C'.
    :param out: ndarray, or None, optional
        A location into which the result is stored. If provided, it must have
        the same shape as the *flattened* input and the calculation dtype.
        If not provided or `None`, a freshly-allocated array is returned.
    :returns: 1-D ndarray with the moving average (`out` if it was given).
    :raises ValueError: if `out` has the wrong shape or dtype.
    """
    # np.asarray avoids a copy when possible. np.array(..., copy=False)
    # raises under NumPy 2 whenever a copy is needed (e.g. list input).
    data = np.asarray(data)

    if dtype is None:
        dtype = np.float32 if data.dtype == np.float32 else np.float64
    dtype = np.dtype(dtype)

    if data.ndim != 1:
        # `order` must be passed by keyword: ndarray.reshape treats extra
        # positional arguments as additional dimensions.
        data = data.reshape(-1, order=order)

    if out is None:
        out = np.empty_like(data, dtype=dtype)
    else:
        if out.shape != data.shape:
            raise ValueError(
                "out has shape {}, expected {}".format(out.shape, data.shape))
        if out.dtype != dtype:
            raise ValueError(
                "out has dtype {}, expected {}".format(out.dtype, dtype))

    if data.size < 1:
        # empty input, return empty array
        return out

    if offset is None:
        offset = data[0]

    alpha = np.asarray(alpha, dtype=dtype)

    # scaling_factors -> 0 as len(data) gets large
    # this leads to divide-by-zeros below
    scaling_factors = np.power(1. - alpha, np.arange(data.size + 1, dtype=dtype),
                               dtype=dtype)
    # create cumulative sum array
    np.multiply(data, (alpha * scaling_factors[-2]) / scaling_factors[:-1],
                dtype=dtype, out=out)
    np.cumsum(out, dtype=dtype, out=out)

    # cumsums / scaling
    out /= scaling_factors[-2::-1]

    if offset != 0:
        offset = np.asarray(offset, dtype=dtype)
        # add offsets
        out += offset * scaling_factors[1:]

    return out

def moving_average(a, n=3):
    """Return the simple (unweighted) moving average of `a` over windows of `n`.

    Uses the cumulative-sum trick: ``window_sum[i] = csum[i] - csum[i - n]``.
    The result has ``len(a) - n + 1`` entries (empty if ``n > len(a)``).
    Multi-dimensional input is flattened by ``np.cumsum``.

    :param a: array-like of numbers.
    :param n: Window length; must be >= 1.
    :returns: float ndarray of window means.
    :raises ValueError: if ``n < 1``. A zero window is undefined, and a
        negative one previously produced meaningless slices.
    """
    if n < 1:
        raise ValueError("window size n must be >= 1, got {!r}".format(n))
    csum = np.cumsum(a, dtype=float)
    csum[n:] = csum[n:] - csum[:-n]
    return csum[n - 1:] / n

def custom_exp_ma(data, alpha = 0.1):
    data = np.array(data)

